"""
Source: https://github.com/KaiyangZhou/deep-person-reid
"""
import torch
from torch.nn import functional as F


def compute_distance_matrix(input1, input2, metric="euclidean"):
    """Validate two feature matrices and compute their pairwise distances.

    Both inputs must be 2-D tensors with the same size along dimension 1.
    The actual computation is delegated to the routine matching ``metric``.

    Args:
        input1 (torch.Tensor): 2-D feature matrix.
        input2 (torch.Tensor): 2-D feature matrix.
        metric (str, optional): "euclidean" or "cosine".
            Default is "euclidean". Note that "euclidean" dispatches to
            ``euclidean_squared_distance``, i.e. no square root is taken.

    Returns:
        torch.Tensor: distance matrix.

    Raises:
        AssertionError: if an input is not a tensor, is not 2-D, or the
            sizes along dimension 1 differ.
        ValueError: if ``metric`` is neither "euclidean" nor "cosine".
    """
    # Type checks come first for both operands, then the rank checks,
    # so a non-tensor argument is always reported before a wrong rank.
    for features in (input1, input2):
        assert isinstance(features, torch.Tensor)
    for features in (input1, input2):
        assert features.dim() == 2, "Expected 2-D tensor, but got {}-D".format(
            features.dim()
        )
    assert input1.size(1) == input2.size(1)

    # Dispatch on the requested metric; unknown names fall through to the error.
    if metric == "euclidean":
        return euclidean_squared_distance(input1, input2)
    if metric == "cosine":
        return cosine_distance(input1, input2)

    raise ValueError(
        "Unknown distance metric: {}. "
        'Please choose either "euclidean" or "cosine"'.format(metric)
    )


def euclidean_squared_distance(input1, input2):
    """Computes euclidean squared distance.

    Uses the expansion ``||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b`` so the
    whole matrix is obtained from one matrix product instead of an explicit
    loop over row pairs. No square root is taken.

    Args:
        input1 (torch.Tensor): 2-D feature matrix of shape (m, feature_dim).
        input2 (torch.Tensor): 2-D feature matrix of shape (n, feature_dim).

    Returns:
        torch.Tensor: distance matrix of shape (m, n), where entry (i, j) is
        the squared euclidean distance between ``input1[i]`` and
        ``input2[j]``.
    """
    # Per-row squared norms: (m, 1) for input1 and (n, 1) for input2.
    sq_norm1 = torch.pow(input1, 2).sum(dim=1, keepdim=True)
    sq_norm2 = torch.pow(input2, 2).sum(dim=1, keepdim=True)
    # Broadcasting (m, 1) + (1, n) allocates a fresh (m, n) tensor, so the
    # in-place update below never writes into the caller's inputs.
    distmat = sq_norm1 + sq_norm2.t()
    # distmat = 1 * distmat + (-2) * (input1 @ input2^T). The keyword form
    # replaces the deprecated positional (beta, alpha, mat1, mat2) overload.
    distmat.addmm_(input1, input2.t(), beta=1, alpha=-2)
    return distmat


def cosine_distance(input1, input2):
    """Computes cosine distance.

    Each row of both inputs is scaled to unit L2 norm; the distance is then
    one minus the dot product of the normalised rows.

    Args:
        input1 (torch.Tensor): 2-D feature matrix.
        input2 (torch.Tensor): 2-D feature matrix.

    Returns:
        torch.Tensor: distance matrix, with one row per row of ``input1``
        and one column per row of ``input2``.
    """
    # L2-normalise along the feature dimension so dot products are cosines.
    unit_rows1, unit_rows2 = (
        F.normalize(features, p=2, dim=1) for features in (input1, input2)
    )
    similarity = unit_rows1.mm(unit_rows2.t())
    return 1 - similarity
